"""
This file contains several utility functions for working with environment
wrappers provided by the repository, and with environment metadata saved
in dataset files.
"""
from copy import deepcopy
import robomimic.envs.env_base as EB
from robomimic.utils.log_utils import log_warning


def get_env_class(env_meta=None, env_type=None, env=None):
    """
    Return env class from either env_meta, env_type, or env.
    Note the use of lazy imports - this ensures that modules are only
    imported when the corresponding env type is requested. This can
    be useful in practice. For example, a training run that only
    requires access to gym environments should not need to import
    robosuite.

    Args:
        env_meta (dict): environment metadata, which should be loaded from demonstration
            hdf5 with @FileUtils.get_env_metadata_from_dataset or from checkpoint (see
            @FileUtils.env_from_checkpoint). Contains 3 keys:

                :`'env_name'`: name of environment
                :`'type'`: type of environment, should be a value in EB.EnvType
                :`'env_kwargs'`: dictionary of keyword arguments to pass to environment constructor

        env_type (int): the type of environment, which determines the env class that will
            be instantiated. Should be a value in EB.EnvType.

        env (instance of EB.EnvBase): environment instance

    Returns:
        env class (subclass of EB.EnvBase) corresponding to the resolved env type

    Raises:
        ValueError: if the resolved env type is not one of the types handled below
    """
    env_type = get_env_type(env_meta=env_meta, env_type=env_type, env=env)
    if env_type == EB.EnvType.ROBOSUITE_TYPE:
        from robomimic.envs.env_robosuite import EnvRobosuite
        return EnvRobosuite
    elif env_type == EB.EnvType.GYM_TYPE:
        from robomimic.envs.env_gym import EnvGym
        return EnvGym
    elif env_type == EB.EnvType.IG_MOMART_TYPE:
        from robomimic.envs.env_ig_momart import EnvGibsonMOMART
        return EnvGibsonMOMART
    elif env_type == EB.EnvType.REAL_TYPE:
        from robomimic.envs.env_real_panda import EnvRealPanda
        return EnvRealPanda
    elif env_type == EB.EnvType.GPRS_REAL_TYPE:
        from robomimic.envs.env_real_panda_gprs import EnvRealPandaGPRS
        return EnvRealPandaGPRS
    # reachable whenever metadata carries an unknown / unsupported type, so report the value
    # (ValueError subclasses Exception, so existing `except Exception` callers still work)
    raise ValueError("get_env_class: unsupported env type {!r}".format(env_type))


def get_env_type(env_meta=None, env_type=None, env=None):
    """
    Resolve an environment type from exactly one of several possible sources.

    Args:
        env_meta (dict): environment metadata, which should be loaded from demonstration
            hdf5 with @FileUtils.get_env_metadata_from_dataset or from checkpoint (see
            @FileUtils.env_from_checkpoint). Contains 3 keys:

                :`'env_name'`: name of environment
                :`'type'`: type of environment, should be a value in EB.EnvType
                :`'env_kwargs'`: dictionary of keyword arguments to pass to environment constructor

        env_type (int): the type of environment, which determines the env class that will
            be instantiated. Should be a value in EB.EnvType.

        env (instance of EB.EnvBase): environment instance

    Returns:
        env_meta["type"], env.type, or env_type - whichever source was given.

    Raises:
        AssertionError: if not exactly one of the three inputs is provided
    """
    num_given = sum(source is not None for source in (env_meta, env_type, env))
    assert num_given == 1, "should provide only one of env_meta, env_type, env"

    if env_meta is not None:
        return env_meta["type"]
    if env is not None:
        return env.type
    # only @env_type was supplied - pass it through unchanged
    return env_type


def check_env_type(type_to_check, env_meta=None, env_type=None, env=None):
    """
    Report whether the environment described by env_meta, env_type, or env
    (exactly one of them, as enforced by @get_env_type) has type @type_to_check.
    Types are values of EB.EnvType.

    Args:
        type_to_check (int): type to check equality against

        env_meta (dict): environment metadata, which should be loaded from demonstration
            hdf5 with @FileUtils.get_env_metadata_from_dataset or from checkpoint (see
            @FileUtils.env_from_checkpoint). Contains 3 keys:

                :`'env_name'`: name of environment
                :`'type'`: type of environment, should be a value in EB.EnvType
                :`'env_kwargs'`: dictionary of keyword arguments to pass to environment constructor

        env_type (int): the type of environment, which determines the env class that will
            be instantiated. Should be a value in EB.EnvType.

        env (instance of EB.EnvBase): environment instance

    Returns:
        bool: True when the resolved type equals @type_to_check
    """
    resolved_type = get_env_type(env_meta=env_meta, env_type=env_type, env=env)
    is_match = resolved_type == type_to_check
    return is_match


def check_env_version(env, env_meta):
    """
    Compare the installed environment version against the version recorded in
    @env_meta and log a warning when it is missing or different. Nothing is
    returned and no exception is raised for a mismatch.

    Args:
        env (instance of EB.EnvBase): environment instance (its `version` attribute is read)

        env_meta (dict): environment metadata, which should be loaded from demonstration
            hdf5 with @FileUtils.get_env_metadata_from_dataset or from checkpoint (see
            @FileUtils.env_from_checkpoint). Contains following key:

                :`'env_version'`: environment version, type str
    """
    installed_version = env.version
    dataset_version = env_meta.get("env_version", None)

    # older datasets may not record a version at all - nothing to compare against
    if dataset_version is None:
        log_warning(
            "No environment version found in dataset!"
            "\nCannot verify if dataset and installed environment versions match"
        )
        return

    if installed_version != dataset_version:
        mismatch_msg = (
            "Dataset and installed environment version mismatch!"
            "\nDataset environment version: {meta}"
            "\nInstalled environment version: {sys}"
        ).format(sys=installed_version, meta=dataset_version)
        log_warning(mismatch_msg)


def is_robosuite_env(env_meta=None, env_type=None, env=None):
    """
    Return True if the given env_meta / env_type / env refers to a robosuite
    environment (EB.EnvType.ROBOSUITE_TYPE). Exactly one of the three inputs
    should be provided.
    """
    robosuite_type = EB.EnvType.ROBOSUITE_TYPE
    return check_env_type(robosuite_type, env_meta=env_meta, env_type=env_type, env=env)


def is_simpler_env(env_meta=None, env_type=None, env=None):
    """
    Stub predicate: the inputs are ignored and the result is always False.
    """
    del env_meta, env_type, env  # intentionally unused
    return False


def is_simpler_ov_env(env_meta=None, env_type=None, env=None):
    """
    Stub predicate: the inputs are ignored and the result is always False.
    """
    del env_meta, env_type, env  # intentionally unused
    return False


def is_factory_env(env_meta=None, env_type=None, env=None):
    """
    Stub predicate: the inputs are ignored and the result is always False.
    """
    del env_meta, env_type, env  # intentionally unused
    return False


def is_furniture_sim_env(env_meta=None, env_type=None, env=None):
    """
    Stub predicate: the inputs are ignored and the result is always False.
    """
    del env_meta, env_type, env  # intentionally unused
    return False


def is_real_robot_env(env_meta=None, env_type=None, env=None):
    """
    Return True if the given env_meta / env_type / env refers to a real robot
    environment of type EB.EnvType.REAL_TYPE. Exactly one of the three inputs
    should be provided.
    """
    real_type = EB.EnvType.REAL_TYPE
    return check_env_type(real_type, env_meta=env_meta, env_type=env_type, env=env)


def is_real_robot_gprs_env(env_meta=None, env_type=None, env=None):
    """
    Determines whether the environment is a GPRS real robot environment
    (EB.EnvType.GPRS_REAL_TYPE, the type that @get_env_class maps to
    EnvRealPandaGPRS). Accepts either env_meta, env_type, or env.
    """
    return check_env_type(type_to_check=EB.EnvType.GPRS_REAL_TYPE, env_meta=env_meta, env_type=env_type, env=env)


def create_env(
    env_type,
    env_name,
    env_class=None,
    render=False, 
    render_offscreen=False, 
    use_image_obs=False, 
    use_depth_obs=False,
    **kwargs,
):
    """
    Create environment.

    Args:
        env_type (int): the type of environment, which determines the env class that will
            be instantiated. Should be a value in EB.EnvType.

        env_name (str): name of environment

        env_class (class or None): environment class to instantiate. If None, it is looked
            up from @env_type with @get_env_class.

        render (bool): if True, environment supports on-screen rendering

        render_offscreen (bool): if True, environment supports off-screen rendering. This
            is expected to be forced to True if @use_image_obs is True (not enforced here;
            presumably handled by the env class - confirm).

        use_image_obs (bool): if True, environment is expected to render rgb image observations
            on every env.step call. Set this to False for efficiency reasons, if image
            observations are not required.

        use_depth_obs (bool): if True, environment is expected to render depth image observations
            on every env.step call. Set this to False for efficiency reasons, if depth
            observations are not required. Forwarded to the env constructor only when True.

        **kwargs: additional keyword arguments passed through to the env constructor

    Returns:
        env (instance of EB.EnvBase): the created environment
    """
    if env_class is None:
        env_class = get_env_class(env_type=env_type)

    # Forward the depth request so it is not silently dropped. Only pass it when requested,
    # so env classes whose constructor lacks a use_depth_obs argument still work by default.
    # (@kwargs is a fresh dict local to this call, so adding a key here is safe.)
    if use_depth_obs:
        kwargs["use_depth_obs"] = use_depth_obs

    # note: pass @postprocess_visual_obs True, to make sure images are processed for network inputs
    env = env_class(
        env_name=env_name, 
        render=render, 
        render_offscreen=render_offscreen, 
        use_image_obs=use_image_obs,
        postprocess_visual_obs=True,
        **kwargs,
    )
    print("Created environment with name {}".format(env_name))
    print("Action size is {}".format(env.action_dimension))
    return env


def create_env_from_metadata(
    env_meta,
    env_name=None,
    env_class=None,  
    render=False, 
    render_offscreen=False, 
    use_image_obs=False, 
    use_depth_obs=False, 
):
    """
    Create environment.

    Args:
        env_meta (dict): environment metadata, which should be loaded from demonstration
            hdf5 with @FileUtils.get_env_metadata_from_dataset or from checkpoint (see
            @FileUtils.env_from_checkpoint). Contains 3 keys:

                :`'env_name'`: name of environment
                :`'type'`: type of environment, should be a value in EB.EnvType
                :`'env_kwargs'`: dictionary of keyword arguments to pass to environment constructor

            This dict is not modified.

        env_name (str): name of environment. Only needs to be provided if making a different
            environment from the one in @env_meta.

        env_class (class or None): environment class to instantiate. If None, it is derived
            from the type stored in @env_meta.

        render (bool): if True, environment supports on-screen rendering

        render_offscreen (bool): if True, environment supports off-screen rendering. This
            is expected to be forced to True if @use_image_obs is True (not enforced here).

        use_image_obs (bool): if True, environment is expected to render rgb image observations
            on every env.step call. Set this to False for efficiency reasons, if image
            observations are not required.

        use_depth_obs (bool): if True, environment is expected to render depth image observations
            on every env.step call. Set this to False for efficiency reasons, if depth
            observations are not required.

    Returns:
        env (instance of EB.EnvBase): the created environment. A warning is logged by
            @check_env_version if its version differs from the one in @env_meta.
    """
    if env_name is None:
        env_name = env_meta["env_name"]
    env_type = get_env_type(env_meta=env_meta)

    # copy before popping so the caller's @env_meta (e.g. loaded from a checkpoint) is not
    # mutated in place; the explicit arguments below take precedence over stored values
    env_kwargs = deepcopy(env_meta["env_kwargs"])
    env_kwargs.pop("use_image_obs", None)
    env_kwargs.pop("use_depth_obs", None)

    env = create_env(
        env_type=env_type,
        env_name=env_name,  
        env_class=env_class, 
        render=render, 
        render_offscreen=render_offscreen, 
        use_image_obs=use_image_obs, 
        use_depth_obs=use_depth_obs,
        **env_kwargs,
    )
    check_env_version(env, env_meta)
    return env


def create_env_for_data_processing(
    env_meta,
    camera_names, 
    camera_height, 
    camera_width, 
    reward_shaping,
    env_class=None,
    render=None, 
    render_offscreen=None, 
    use_image_obs=None, 
    use_depth_obs=None, 
):
    """
    Creates environment for processing dataset observations and rewards.

    Args:
        env_meta (dict): environment metadata, which should be loaded from demonstration
            hdf5 with @FileUtils.get_env_metadata_from_dataset or from checkpoint (see
            @FileUtils.env_from_checkpoint). Contains 3 keys:

                :`'env_name'`: name of environment
                :`'type'`: type of environment, should be a value in EB.EnvType
                :`'env_kwargs'`: dictionary of keyword arguments to pass to environment constructor

            This dict is not modified (its env_kwargs are deep-copied).

        camera_names (list of str): list of camera names that correspond to image observations

        camera_height (int): camera height for all cameras

        camera_width (int): camera width for all cameras

        reward_shaping (bool): if True, use shaped environment rewards, else use sparse task completion rewards

        env_class (class or None): environment class whose @create_for_data_processing is used.
            If None, it is derived from the type stored in @env_meta.

        render (bool or None): optionally override rendering behavior

        render_offscreen (bool or None): optionally override rendering behavior

        use_image_obs (bool or None): optionally override rendering behavior

        use_depth_obs (bool or None): optionally override rendering behavior

    Returns:
        env (instance of EB.EnvBase): the created environment. A warning is logged by
            @check_env_version if its version differs from the one in @env_meta.
    """
    env_name = env_meta["env_name"]
    env_type = get_env_type(env_meta=env_meta)
    env_kwargs = env_meta["env_kwargs"]
    if env_class is None:
        env_class = get_env_class(env_type=env_type)

    # remove possibly redundant values in kwargs - these are passed explicitly below and
    # would otherwise cause duplicate keyword arguments. Copy first so @env_meta is untouched.
    env_kwargs = deepcopy(env_kwargs)
    env_kwargs.pop("env_name", None)
    env_kwargs.pop("camera_names", None)
    env_kwargs.pop("camera_height", None)
    env_kwargs.pop("camera_width", None)
    env_kwargs.pop("reward_shaping", None)
    env_kwargs.pop("render", None)
    env_kwargs.pop("render_offscreen", None)
    env_kwargs.pop("use_image_obs", None)
    env_kwargs.pop("use_depth_obs", None)

    env = env_class.create_for_data_processing(
        env_name=env_name, 
        camera_names=camera_names, 
        camera_height=camera_height, 
        camera_width=camera_width, 
        reward_shaping=reward_shaping, 
        render=render, 
        render_offscreen=render_offscreen, 
        use_image_obs=use_image_obs, 
        use_depth_obs=use_depth_obs,
        **env_kwargs,
    )
    check_env_version(env, env_meta)
    return env


def set_env_specific_obs_processing(env_meta=None, env_type=None, env=None):
    """
    Sets env-specific observation processing. As an example, robosuite depth observations
    correspond to raw depth and should not be normalized by default, while default depth
    processing normalizes and clips all values to [0, 1]. As another example, depth
    observations on the real robot are uint16 and will be converted to float during processing.

    Accepts either env_meta, env_type, or env (exactly one, see @get_env_type). For env types
    other than robosuite and GPRS real robot, nothing is changed.

    This works by installing processor / unprocessor callables on DepthModality, so it
    affects depth processing globally for subsequent calls.
    """
    if is_robosuite_env(env_meta=env_meta, env_type=env_type, env=env):
        from robomimic.utils.obs_utils import DepthModality, process_frame, unprocess_frame
        # scale=None: keep raw robosuite depth values instead of normalizing them
        DepthModality.set_obs_processor(processor=(
            lambda obs: process_frame(frame=obs, channel_dim=1, scale=None)
        ))
        DepthModality.set_obs_unprocessor(unprocessor=(
            lambda obs: unprocess_frame(frame=obs, channel_dim=1, scale=None)
        ))
    elif is_real_robot_gprs_env(env_meta=env_meta, env_type=env_type, env=env):
        from robomimic.envs.env_real_panda_gprs import get_depth_scale
        from robomimic.utils.obs_utils import DepthModality, batch_image_hwc_to_chw
        from robomimic.utils.tensor_utils import to_float

        # NOTE: assuming that depth scales for front and wrist camera are about the same right now...
        scale = get_depth_scale(camera_name="front")

        def new_process_frame(frame):
            # expects channel-last depth with a single channel; converts to float, applies the
            # camera depth scale, and moves the channel axis first
            assert (frame.shape[-1] == 1)
            frame = to_float(frame)
            frame *= scale
            return batch_image_hwc_to_chw(frame)

        def new_unprocess_frame(frame):
            # Deliberately unsupported. The inverse of new_process_frame would convert CHW back to
            # HWC, divide by @scale, and cast back to uint16, but torch does not support uint16.
            raise Exception("real robot depth unprocessor is wrong since torch does not support uint16")

        DepthModality.set_obs_processor(processor=new_process_frame)
        DepthModality.set_obs_unprocessor(unprocessor=new_unprocess_frame)


def wrap_env_from_config(env, config):
    """
    Wrap @env according to @config and return the (possibly) wrapped env.

    Currently the only wrapper is FrameStackWrapper, applied when
    config.train.frame_stack is greater than 1; otherwise @env is returned as-is.
    """
    num_frames = config.train.frame_stack
    if num_frames > 1:
        # lazy import so the wrapper module is only loaded when frame stacking is used
        from robomimic.envs.wrappers import FrameStackWrapper
        return FrameStackWrapper(env, num_frames=num_frames)
    return env
